[DAGMM] DAGMM implementation on the arrhythmia dataset

dagmm
Author

kione kim

Published

October 19, 2023

Deep Autoencoding Gaussian Mixture Model (DAGMM) for the arrhythmia dataset

imports

import torch
from torch import nn
import numpy as np
import pandas as pd
import argparse
import sys

data set

file_path = 'C:\\Users\\UOS\\Desktop\\연구\\5. 데이터\\data\\arrhythmia\\arrhythmia.data'

df = pd.read_csv(file_path, header=None)
df = df.replace('?', 0)    # missing values are marked '?' in the raw file
df = df.astype('float64')

data_array = df.values
# torch.autograd.Variable is deprecated since PyTorch 0.4 and now simply returns the tensor
data_array = torch.autograd.Variable(torch.from_numpy(data_array).float())
data_array.shape
torch.Size([452, 280])
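
The path above is machine-specific. If the raw file is not available locally, the same data can be read straight from the UCI Machine Learning Repository; a minimal sketch, assuming the legacy UCI download URL is still live:

# hypothetical alternative to the local path above
uci_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/arrhythmia/arrhythmia.data'
df = pd.read_csv(uci_url, header=None)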

argparse

parser = argparse.ArgumentParser(description='parser for argparse test')

parser.add_argument('--input_dim', type=int, default=data_array.shape[-1])
parser.add_argument('--enc_hidden_dim', type=str, default='10,2')
parser.add_argument('--dec_hidden_dim', type=str, default='10')
parser.add_argument('--est_hidden_dim', type=str, default='4, 10, 2')
parser.add_argument('--dropout', type=float, default=0.5)  # action='store_true' would make this a bool flag, not a rate
parser.add_argument('--learning_rate', type=float, default=0.001)
parser.add_argument('--num_epoch', type=int, default=10)

# strip the extra arguments Jupyter injects so parse_args() works inside a notebook
if 'ipykernel_launcher' in sys.argv[0]:
    sys.argv = [sys.argv[0]]

args = parser.parse_args()

enc_hidden_dim = args.enc_hidden_dim.split(',')
dec_hidden_dim = args.dec_hidden_dim.split(',')
est_hidden_dim = args.est_hidden_dim.split(',')

args.enc_hidden_dim_list = []
args.dec_hidden_dim_list = []
args.est_hidden_dim_list = []

# encoder dims: input_dim followed by the encoder hidden dims, e.g. [280, 10, 2]
args.enc_hidden_dim_list.append(args.input_dim)

for i in enc_hidden_dim:
    args.enc_hidden_dim_list.append(int(i))

args.enc_hidden_dim_list

# decoder dims mirror the encoder: latent dim, decoder hidden dims, then input_dim
args.dec_hidden_dim_list.append(args.enc_hidden_dim_list[-1])

for i in dec_hidden_dim:
    args.dec_hidden_dim_list.append(int(i))

args.dec_hidden_dim_list.append(args.input_dim)

args.dec_hidden_dim_list

for i in est_hidden_dim:
    args.est_hidden_dim_list.append(int(i))

args.est_hidden_dim_list

args
Namespace(input_dim=280, enc_hidden_dim='10,2', dec_hidden_dim='10', est_hidden_dim='4, 10, 2', dropout=0.5, learning_rate=0.001, num_epoch=10, enc_hidden_dim_list=[280, 10, 2], dec_hidden_dim_list=[2, 10, 280], est_hidden_dim_list=[4, 10, 2])
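
When the script is run from a shell instead of a notebook, these defaults can be overridden on the command line. A hypothetical invocation (the script name is illustrative; note the estimation network's first dim must equal the latent dim plus two):

python dagmm.py --enc_hidden_dim 16,4 --est_hidden_dim 6,10,2 --num_epoch 50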

Compression network

class midlayer(nn.Module):
    """A fully connected layer followed by tanh: the basic building block below."""
    def __init__(self, input_dim, hidden_dim):
        super(midlayer, self).__init__()
        self.fc_layer   = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.Tanh()
    
    def forward(self, input):
        out = self.fc_layer(input)        
        out = self.activation(out)
        return out


class Encoder(nn.Module):
    def __init__(self, hidden_dim_list):
        super(Encoder, self).__init__()
        
        layer_list = []
        for i in range(len(hidden_dim_list)-2):
            layer_list.append(midlayer(hidden_dim_list[i], hidden_dim_list[i+1]))

        # the final encoder layer is linear (no activation), so the latent space is unbounded
        layer_list.append(nn.Linear(hidden_dim_list[-2], hidden_dim_list[-1]))
        self.layer = nn.Sequential(*layer_list)

    def forward(self, input):
        out = self.layer(input)
        return out
    
class Decoder(nn.Module):
    def __init__(self, hidden_dim_list):
        super(Decoder, self).__init__()

        layer_list = []
        for i in range(len(hidden_dim_list)-2):
            layer_list.append(midlayer(hidden_dim_list[i], hidden_dim_list[i+1]))

        # the final decoder layer keeps the tanh activation
        layer_list.append(midlayer(hidden_dim_list[-2], hidden_dim_list[-1]))
        self.layer = nn.Sequential(*layer_list)
    
    def forward(self, input):
        out = self.layer(input)
        return out

class CompressionNet(nn.Module):
    def __init__(self, enc_hidden_dim_list, dec_hidden_dim_list):
        super().__init__()
        self.encoder = Encoder(enc_hidden_dim_list)
        self.decoder = Decoder(dec_hidden_dim_list)

        self._reconstruction_loss = nn.MSELoss()

    def forward(self, input):
        out = self.encoder(input)
        out = self.decoder(out)
        return out

    def encode(self, input):
        return self.encoder(input)

    def decode(self, input):
        return self.decoder(input)

    def reconstruction_loss(self, input, input_target):
        target_hat = self(input)
        return self._reconstruction_loss(target_hat, input_target)
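
A quick shape check on the network just defined; a minimal sketch using the dimension lists built above:

comp = CompressionNet(args.enc_hidden_dim_list, args.dec_hidden_dim_list)
z    = comp.encode(data_array)   # latent codes: torch.Size([452, 2])
xhat = comp(data_array)          # reconstructions: torch.Size([452, 280])
print(z.shape, xhat.shape)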

reconstruction error

# small constant that guards against division by zero in the denominators below
eps = torch.autograd.Variable(torch.FloatTensor([1.e-8]), requires_grad=False)

def relative_euclidean_distance(x1, x2, eps=eps):
    num = torch.norm(x1 - x2, p=2, dim=1)
    denom = torch.norm(x1, p=2, dim=1)
    return num / torch.max(denom, eps)

def cosine_similarity(x1, x2, eps=eps):
    dot_prod = torch.sum(x1 * x2, dim=1)
    dist_x1 = torch.norm(x1, p=2, dim=1)
    dist_x2 = torch.norm(x2, p=2, dim=1)
    return dot_prod / torch.max(dist_x1*dist_x2, eps)
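
Both functions collapse a batch of vector pairs into one scalar per row; a small sanity check on toy tensors:

a = torch.ones(3, 4)
b = torch.zeros(3, 4)
print(relative_euclidean_distance(a, b))   # tensor([1., 1., 1.])
print(cosine_similarity(a, a))             # tensor([1., 1., 1.])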

Estimation network

class Estimation(nn.Module):
    def __init__(self, est_hidden_dim_list):
        super().__init__()
        
        layer_list = []
        for i in range(len(est_hidden_dim_list)-2):
            layer_list.append(midlayer(est_hidden_dim_list[i], est_hidden_dim_list[i+1]))
        
        layer_list.append(nn.Dropout(p=0.5))
        layer_list.append(nn.Linear(est_hidden_dim_list[-2], est_hidden_dim_list[-1]))
        layer_list.append(nn.Softmax(dim=1))  # normalize over mixture components; omitting dim raises a deprecation warning
        self.net = nn.Sequential(*layer_list)
        
    def forward(self, input):
        out = self.net(input)
        return out
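
With dim=1 the softmax normalizes across mixture components, so every row of the output is a valid membership distribution; a quick sketch:

est = Estimation(args.est_hidden_dim_list)   # [4, 10, 2]
est.eval()                                   # disable dropout for the check
gamma = est(torch.randn(5, 4))
print(gamma.sum(dim=1))                      # each entry ~1.0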

Mixture

class Mixture(nn.Module):
    def __init__(self, latent_dimension):
        super().__init__()
        self.latent_dimension = latent_dimension

        # φ, μ, and Σ are fitted by the EM-style moment updates in
        # _update_parameters rather than by gradient descent, hence requires_grad=False
        self.Phi    = np.random.random([1])
        self.Phi    = torch.from_numpy(self.Phi).float()
        self.Phi    = nn.Parameter(self.Phi, requires_grad = False)

        self.mu     = 2.*np.random.random([latent_dimension]) - 0.5
        self.mu     = torch.from_numpy(self.mu).float()
        self.mu     = nn.Parameter(self.mu, requires_grad = False)

        self.Sigma  = np.eye(latent_dimension, latent_dimension)
        self.Sigma  = torch.from_numpy(self.Sigma).float()
        self.Sigma  = nn.Parameter(self.Sigma, requires_grad = False)
        
        # small diagonal jitter added to Σ after each update to keep it invertible
        self.eps_Sigma  = torch.FloatTensor(np.diag([1.e-8 for _ in range(latent_dimension)]))

    def forward(self, est_inputs, with_log = True):
        batch_size, _   = est_inputs.shape
        out_values  = []
        inv_sigma   = torch.inverse(self.Sigma)
        det_sigma   = np.linalg.det(self.Sigma.data.cpu().numpy())
        det_sigma   = torch.from_numpy(det_sigma.reshape([1])).float()
        det_sigma   = torch.autograd.Variable(det_sigma)
        for est_input in est_inputs:
            diff    = (est_input - self.mu).view(-1,1)
            out     = -0.5 * torch.mm(torch.mm(diff.view(1,-1), inv_sigma), diff)
            # normalization constant of a d-dimensional Gaussian: sqrt((2π)^d |Σ|)
            out     = (self.Phi * torch.exp(out)) / torch.sqrt(((2. * np.pi) ** self.latent_dimension) * det_sigma)
            if with_log:
                out = -torch.log(out)
            out_values.append(out.view(-1))

        # concatenate per-sample values as tensors (rather than Python floats)
        # so the computation graph stays intact for backpropagation
        out = torch.cat(out_values)
        return out
    
    def _update_parameters(self, samples, affiliations):
        if not self.training:
            return

        batch_size, _ = samples.shape

        # Updating phi.
        phi = torch.mean(affiliations)
        self.Phi.data = phi.data

        # Updating mu.
        num = 0.
        for i in range(batch_size):
            z_i     = samples[i, :]
            gamma_i = affiliations[i]
            num     += gamma_i * z_i
        
        denom        = torch.sum(affiliations)
        self.mu.data = (num / denom).data

        # Updating Sigma.
        mu  = self.mu
        num = None
        for i in range(batch_size):
            z_i     = samples[i, :]
            gamma_i = affiliations[i]
            diff    = (z_i - mu).view(-1, 1)
            to_add  = gamma_i * torch.mm(diff, diff.view(1, -1))
            if num is None:
                num = to_add
            else:
                num += to_add

        denom           = torch.sum(affiliations)
        self.Sigma.data = (num / denom).data + self.eps_Sigma
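
For reference, _update_parameters implements the moment estimates from the DAGMM paper, with $\hat{\gamma}_i$ the soft membership of sample $i$ in this component:

$$\hat{\phi} = \frac{1}{N}\sum_{i=1}^{N}\hat{\gamma}_i, \qquad \hat{\mu} = \frac{\sum_{i=1}^{N}\hat{\gamma}_i z_i}{\sum_{i=1}^{N}\hat{\gamma}_i}, \qquad \hat{\Sigma} = \frac{\sum_{i=1}^{N}\hat{\gamma}_i (z_i-\hat{\mu})(z_i-\hat{\mu})^\top}{\sum_{i=1}^{N}\hat{\gamma}_i}$$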

Gaussian Mixture Model

class GMM(nn.Module):
    def __init__(self, num_mixtures, latent_dimension):
        super().__init__()
        self.num_mixtures       = num_mixtures
        self.latent_dimension   = latent_dimension

        mixtures        = [Mixture(latent_dimension) for _ in range(num_mixtures)]
        self.mixtures   = nn.ModuleList(mixtures)
    
    def forward(self, est_inputs):
        out = None
        for mixture in self.mixtures:
            to_add  = mixture(est_inputs, with_log = False)
            if out is None:
                out = to_add
            else:
                out += to_add
        return -torch.log(out)
    
    def _update_mixtures_parameters(self, samples, mixtures_affiliations):
        if not self.training:
            return

        for i, mixture in enumerate(self.mixtures):
            affiliations = mixtures_affiliations[:, i]
            mixture._update_parameters(samples, affiliations)
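
Summing the component likelihoods and taking the negative log gives the sample energy from the paper, for a $d$-dimensional latent vector $z$ and $K$ mixtures:

$$E(z) = -\log\left(\sum_{k=1}^{K} \phi_k \, \frac{\exp\!\left(-\tfrac{1}{2}(z-\mu_k)^\top \Sigma_k^{-1} (z-\mu_k)\right)}{\sqrt{(2\pi)^d \lvert \Sigma_k \rvert}}\right)$$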

model

class DAGMM(nn.Module):
    def __init__(self, compression_module, estimation_module, gmm_module):
        super().__init__()

        self.compressor = compression_module
        self.estimator  = estimation_module
        self.gmm        = gmm_module

    def forward(self, input):
        encoded = self.compressor.encode(input)
        decoded = self.compressor.decode(encoded)

        relative_ed     = relative_euclidean_distance(input, decoded)
        cosine_sim      = cosine_similarity(input, decoded)

        relative_ed     = relative_ed.view(-1, 1)
        cosine_sim      = cosine_sim.view(-1, 1)
        latent_vectors  = torch.cat([encoded, relative_ed, cosine_sim], dim=1)

        if self.training:
            mixtures_affiliations = self.estimator(latent_vectors)
            self.gmm._update_mixtures_parameters(latent_vectors,
                                                 mixtures_affiliations)
        return self.gmm(latent_vectors)


class DAGMMArrhythmia(DAGMM):
    def __init__(self, enc_hidden_dim_list, dec_hidden_dim_list, est_hidden_dim_list):
        compressor  = CompressionNet(enc_hidden_dim_list, dec_hidden_dim_list)
        estimator   = Estimation(est_hidden_dim_list)
        gmm = GMM(num_mixtures = est_hidden_dim_list[-1], latent_dimension = enc_hidden_dim_list[-1] + 2)

        super().__init__(compression_module = compressor,
                         estimation_module  = estimator,
                         gmm_module         = gmm)
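
The argparse defaults define learning_rate and num_epoch, but no training loop appears in this post. A minimal full-batch sketch, assuming the simplified objective of reconstruction loss plus λ₁ times the mean sample energy (the paper uses λ₁ = 0.1 and also adds a covariance penalty, omitted here):

net = DAGMMArrhythmia(args.enc_hidden_dim_list, args.dec_hidden_dim_list, args.est_hidden_dim_list)
optimizer = torch.optim.Adam(net.parameters(), lr=args.learning_rate)

net.train()
for epoch in range(args.num_epoch):
    optimizer.zero_grad()
    energy = net(data_array)   # forward pass also refreshes the GMM parameters
    loss = net.compressor.reconstruction_loss(data_array, data_array) + 0.1 * energy.mean()
    loss.backward()
    optimizer.step()
    print(f'epoch {epoch}: loss {loss.item():.4f}')

Because the mixture parameters are updated through .data, gradients do not flow back into the estimation network here; the original paper instead backpropagates through the GMM parameters themselves.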

tests

def test_dagmm():
    net = DAGMMArrhythmia(args.enc_hidden_dim_list, args.dec_hidden_dim_list, args.est_hidden_dim_list)
    out = net(data_array)
    print(out)

def convert_to_var(input):
    out = torch.from_numpy(input).float()
    out = torch.autograd.Variable(out)
    return out

def test_update_mixture():
    batch_size       = 5
    latent_dimension = 7
    mix              = Mixture(latent_dimension)
    latent_vectors   = np.random.random([batch_size, latent_dimension])
    affiliations     = np.random.random([batch_size])
    latent_vectors   = convert_to_var(latent_vectors)
    affiliations     = convert_to_var(affiliations)

    for param in mix.parameters():
        print(param)

    mix.train()
    mix._update_parameters(latent_vectors, affiliations)

    for param in mix.parameters():
        print(param)


def test_forward_mixture():
    batch_size       = 5
    latent_dimension = 7

    mix = Mixture(latent_dimension)
    latent_vectors   = np.random.random([batch_size, latent_dimension])
    latent_vectors   = convert_to_var(latent_vectors)

    mix.train()
    out = mix(latent_vectors)
    print(out)


def test_update_gmm():
    batch_size      = int(5)
    latent_dimension= 7
    num_mixtures    = 2

    gmm = GMM(num_mixtures, latent_dimension)

    latent_vectors  = np.random.random([batch_size, latent_dimension])
    latent_vectors  = convert_to_var(latent_vectors)

    affiliations    = np.random.random([batch_size, num_mixtures])
    affiliations    = convert_to_var(affiliations)

    for param in gmm.parameters():
        print(param)

    gmm.train()
    gmm._update_mixtures_parameters(latent_vectors, affiliations)

    for param in gmm.parameters():
        print(param)

if __name__ == '__main__':
    test_update_mixture()
    test_forward_mixture()
    test_update_gmm()
    test_dagmm()
Parameter containing:
tensor([0.4108])
Parameter containing:
tensor([-0.4930, -0.0609,  0.9678,  0.9646,  1.2854,  1.0585,  1.3091])
Parameter containing:
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
Parameter containing:
tensor(0.5973)
Parameter containing:
tensor([0.5404, 0.8328, 0.5408, 0.5690, 0.4687, 0.4590, 0.4463])
Parameter containing:
tensor([[ 0.1052,  0.0173,  0.0185,  0.0464,  0.0382, -0.0627,  0.0843],
        [ 0.0173,  0.0189,  0.0177,  0.0177,  0.0304, -0.0166, -0.0027],
        [ 0.0185,  0.0177,  0.0652,  0.0079,  0.0596,  0.0047,  0.0050],
        [ 0.0464,  0.0177,  0.0079,  0.0314,  0.0209, -0.0305,  0.0156],
        [ 0.0382,  0.0304,  0.0596,  0.0209,  0.0789, -0.0282,  0.0273],
        [-0.0627, -0.0166,  0.0047, -0.0305, -0.0282,  0.0593, -0.0604],
        [ 0.0843, -0.0027,  0.0050,  0.0156,  0.0273, -0.0604,  0.1217]])
tensor([2.7962, 2.6436, 3.1959, 2.9833, 3.1488])
Parameter containing:
tensor([0.5244])
Parameter containing:
tensor([0.4938, 0.6227, 0.2106, 1.2189, 0.7241, 0.8534, 0.2676])
Parameter containing:
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
Parameter containing:
tensor([0.6656])
Parameter containing:
tensor([0.3613, 0.0844, 0.4508, 0.3979, 0.4715, 0.0697, 0.0790])
Parameter containing:
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
Parameter containing:
tensor(0.3362)
Parameter containing:
tensor([0.5625, 0.1650, 0.5926, 0.5153, 0.7716, 0.5900, 0.4792])
Parameter containing:
tensor([[ 0.0282,  0.0023,  0.0325, -0.0054,  0.0021, -0.0006,  0.0320],
        [ 0.0023,  0.0092, -0.0011,  0.0094, -0.0057, -0.0117,  0.0225],
        [ 0.0325, -0.0011,  0.0922, -0.0192,  0.0473,  0.0083,  0.0474],
        [-0.0054,  0.0094, -0.0192,  0.0462, -0.0479,  0.0173,  0.0209],
        [ 0.0021, -0.0057,  0.0473, -0.0479,  0.0729, -0.0206, -0.0045],
        [-0.0006, -0.0117,  0.0083,  0.0173, -0.0206,  0.0437, -0.0172],
        [ 0.0320,  0.0225,  0.0474,  0.0209, -0.0045, -0.0172,  0.0892]])
Parameter containing:
tensor(0.3915)
Parameter containing:
tensor([0.5520, 0.1997, 0.4993, 0.4864, 0.7780, 0.4641, 0.5020])
Parameter containing:
tensor([[ 0.0396,  0.0077,  0.0342,  0.0191, -0.0283,  0.0112,  0.0568],
        [ 0.0077,  0.0078,  0.0046,  0.0101, -0.0088, -0.0063,  0.0251],
        [ 0.0342,  0.0046,  0.1013,  0.0129,  0.0225,  0.0277,  0.0719],
        [ 0.0191,  0.0101,  0.0129,  0.0373, -0.0402,  0.0164,  0.0465],
        [-0.0283, -0.0088,  0.0225, -0.0402,  0.0746, -0.0143, -0.0359],
        [ 0.0112, -0.0063,  0.0277,  0.0164, -0.0143,  0.0364,  0.0081],
        [ 0.0568,  0.0251,  0.0719,  0.0465, -0.0359,  0.0081,  0.1249]])
tensor([-19.2564, -18.0547, -16.3949, -17.3645, -18.7643, -17.8371, -18.9319,
        -17.3849, -19.0331, -18.5996, -18.8663, -19.0285, -19.3665, -17.3287,
        -18.7925, -17.9085, -18.3725, -16.8432, -19.0094, -19.0351, -17.5516,
        -18.6999, -19.0101, -19.0587, -18.6902, -16.6082, -17.8009, -19.3211,
        -16.4337, -14.4446, -17.7824, -18.5536, -19.2948, -18.8873, -18.1048,
        -18.8613, -18.8028, -17.7254, -17.8467, -17.8801, -17.5114, -18.7899,
        -18.7854, -18.4335, -17.9540, -18.5277, -18.4958, -19.1059, -18.6475,
        -19.0338, -19.3081, -18.6593, -17.3205, -17.4425, -16.4498, -19.1613,
        -18.8155, -19.0773, -18.7134, -18.9291, -18.0730, -18.4230, -18.8852,
        -18.7261, -18.7798, -18.9858, -17.5758, -18.7625, -18.0510, -15.8704,
        -19.0205, -19.0118, -19.3139, -19.0564, -17.2290, -17.8323, -19.1797,
        -18.7321, -18.2683, -17.9285, -18.8088, -17.6969, -19.0502, -18.4027,
        -14.0781, -16.8003, -19.1769, -18.2311, -18.4472, -16.0477, -19.0438,
        -17.5159, -18.0267, -18.5224, -18.4645, -19.3172, -18.2687, -18.6210,
        -18.1220, -18.4591, -18.8159, -18.5173, -17.5662, -18.7601, -19.2286,
        -19.0618, -16.1031, -17.5881, -17.3610, -18.7824, -17.7521, -18.1415,
        -18.9739, -18.4967, -18.8139, -17.4159, -19.1722, -18.6194, -18.9447,
        -18.0765, -19.2638, -17.8369, -18.0603, -17.4801, -17.9954, -17.5410,
        -18.0189, -17.3560, -18.1195, -18.9525, -19.0643, -18.6243, -17.9777,
        -17.0903, -19.3566, -18.5979, -18.8478, -18.3555, -17.4501, -18.2208,
        -18.7513, -15.5237, -19.0146, -18.4070, -19.0495, -18.2230, -17.8902,
        -18.4425, -18.8322, -17.8915, -18.8612, -17.7105, -19.0856, -19.4493,
        -18.6984, -18.7255, -19.0555, -17.6645, -18.8864, -18.5403, -18.1607,
        -17.9466, -18.6024, -17.6435, -18.2158, -19.3479, -19.2290, -18.5409,
        -18.8731, -19.0171, -19.3246, -18.8465, -18.2816, -19.3152, -18.5592,
        -18.4677, -18.3102, -18.8956, -18.7767, -17.8755, -17.4399, -19.0722,
        -18.6370, -17.9408, -17.2431, -18.2954, -18.1857, -16.6689, -17.5154,
        -17.0577, -17.2580, -17.3108, -19.0950, -18.6613, -18.0495, -19.0558,
        -17.7375, -17.1572, -17.4150, -18.8458, -18.6497, -17.9805, -18.1902,
        -18.7660, -15.8835, -18.7459, -18.6498, -17.8905, -15.9800, -17.3226,
        -19.4193, -17.6966, -18.1932, -15.3249, -17.1338, -19.2474, -16.9423,
        -16.6158, -15.0462, -18.6544, -18.8282, -18.0630, -17.2687, -18.8309,
        -19.2707, -18.4543, -18.1300, -17.7864, -18.6297, -18.4028, -18.8658,
        -18.9418, -18.6116, -18.9729, -17.0293, -15.1011, -18.5634, -15.8022,
        -19.0415, -15.4348, -17.6778, -16.9394, -19.3467, -17.6310, -18.2665,
        -19.1546, -19.2717, -18.1000, -18.9345, -19.3095, -18.3156, -17.0592,
        -16.4144, -15.8596, -17.2900, -17.6649, -17.5640, -17.1291, -17.5856,
        -18.8441, -17.9905, -18.2374, -19.0501, -17.4113, -18.6796, -17.8788,
        -17.7550, -17.0484, -18.1735, -18.8908, -17.7271, -19.3673, -18.7071,
        -19.2400, -17.4982, -18.7901, -19.0618, -19.2101, -18.9515, -17.7362,
        -18.8028, -18.1069, -18.6178, -18.1941, -17.9602, -17.4824, -18.9062,
        -19.2635, -17.8047, -19.0641, -18.8086,  -5.6501, -17.8593, -18.3549,
        -17.9025, -17.8254, -18.1989, -18.5610, -18.8534, -19.0492, -16.7777,
        -18.6564, -18.9140, -16.0198, -17.6024, -17.1364, -19.0579, -19.0956,
        -13.0102, -18.8278, -18.8491, -19.0167, -19.3264, -17.9205, -17.3035,
        -18.9889, -16.1662, -19.0933, -16.8775, -17.3989, -16.0942, -19.1201,
        -16.6062, -17.5932, -19.0607, -19.1193, -17.8199, -19.2134, -19.2459,
        -18.4750, -18.6372, -17.2914, -16.9354, -16.4155, -18.6404, -18.6237,
        -13.6548, -17.8241, -17.7731, -18.9976, -16.6893, -18.7639, -19.0192,
        -17.3781, -17.7284, -18.4769, -18.6544, -18.6396, -10.5233, -18.9755,
        -16.2146, -18.3850, -18.8836, -18.0061, -12.4405, -18.1728, -16.0776,
        -18.8860, -17.7030, -17.1750, -18.8864, -16.2724, -18.0929, -17.7909,
        -19.3795, -18.2792, -11.6489, -16.9437, -19.3996, -18.7899, -18.8189,
        -16.4978, -18.9292, -16.5404, -18.1338, -18.5574, -13.8486, -18.1189,
        -18.7550, -17.5262, -17.1936, -11.0568, -18.9244, -18.6801, -18.5698,
        -18.8630, -18.6425, -18.0110, -16.3563, -16.7062, -18.7585, -18.6717,
        -18.0642, -14.7988, -19.3447, -18.4847, -17.6401, -17.0492, -16.6872,
        -18.1093, -19.1730, -17.8896, -18.2686, -15.0559, -18.7812, -17.8376,
        -19.0735, -19.0174, -17.5638, -19.2497, -18.5173, -18.4271, -18.2227,
        -18.6454, -18.4256, -17.6640, -17.9564, -19.0723, -18.4326, -17.5340,
        -15.9750, -19.0241, -18.5688, -17.6339, -18.6477, -19.3596, -17.8437,
        -18.5472, -18.7933, -18.0868, -17.8253, -18.8675, -19.3256, -17.0121,
        -18.9690, -18.8272, -18.9232, -18.9247, -17.0758, -18.1514, -18.4496,
        -17.3666, -17.7035, -18.6327, -19.0636, -19.3950, -18.8716, -17.5528,
        -17.5704, -16.1035, -18.6533, -18.8194])
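
The energies printed above are anomaly scores: higher energy means a less likely sample under the fitted mixture. A minimal thresholding sketch, assuming a net trained as in the loop sketched earlier and that (roughly matching the paper's arrhythmia setup, where about 15% of records are treated as anomalous) the top 15% of energies are flagged:

net.eval()
with torch.no_grad():
    energies = net(data_array)

threshold = torch.quantile(energies, 0.85)   # cut-off at the 85th percentile
pred_anomaly = energies > threshold
print(pred_anomaly.sum().item(), 'samples flagged as anomalies')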

Ref

  • Zong et al., “Deep Autoencoding Gaussian Mixture Model for Unsupervised Anomaly Detection,” ICLR 2018: https://openreview.net/forum?id=BJJLHbb0-